/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "SSHProtocol.h"
#include <cstddef>
#include <iomanip>

namespace aiengine {

// Returns true when the first five bytes of hdr spell "SSH-2".
// Bytes are compared in order and the scan stops at the first mismatch,
// so at most five bytes of hdr are read.
bool SSHProtocol::is_minimal_ssh_header(const uint8_t *hdr) {

	static constexpr char expected_prefix[] = "SSH-2";

	for (unsigned pos = 0; expected_prefix[pos] != '\0'; ++pos) {
		if (hdr[pos] != static_cast<uint8_t>(expected_prefix[pos]))
			return false;
	}
	return true;
}

// Decides whether a packet belongs to SSH.
// A packet is accepted when it carries at least 8 bytes and its payload starts
// with the "SSH-2" prefix (see is_minimal_ssh_header). The payload is only
// stored as header_ when it is long enough. Accepted and rejected packets are
// counted in total_valid_packets_ / total_invalid_packets_.
bool SSHProtocol::check(const Packet &packet) {

	const int payload_length = packet.getLength();

	if (payload_length < 8) {
		++total_invalid_packets_;
		return false;
	}

	setHeader(packet.getPayload());

	if (!is_minimal_ssh_header(header_)) {
		++total_invalid_packets_;
		return false;
	}

	++total_valid_packets_;
	return true;
}

// Forwards the dynamic-allocation flag to both caches owned by this protocol:
// the SSHInfo cache first, then the name (StringCache) cache.
void SSHProtocol::setDynamicAllocatedMemory(bool value) {

	const bool dynamic = value;

	info_cache_->setDynamicAllocatedMemory(dynamic);
	name_cache_->setDynamicAllocatedMemory(dynamic);
}

// Reports the dynamic-allocation flag of the SSHInfo cache only.
// NOTE(review): this assumes name_cache_ holds the same value, which is true
// when the flag is only changed through setDynamicAllocatedMemory().
bool SSHProtocol::isDynamicAllocatedMemory() const {

	const bool dynamic = info_cache_->isDynamicAllocatedMemory();
	return dynamic;
}

// Memory currently in use: the size of this object plus what the SSHInfo
// cache and the name cache report as in use.
uint64_t SSHProtocol::getCurrentUseMemory() const {

	const uint64_t own_bytes = sizeof(SSHProtocol);

	return own_bytes
		+ info_cache_->getCurrentUseMemory()
		+ name_cache_->getCurrentUseMemory();
}

// Memory allocated: the size of this object plus what the SSHInfo cache and
// the name cache report as allocated.
uint64_t SSHProtocol::getAllocatedMemory() const {

	const uint64_t own_bytes = sizeof(SSHProtocol);

	return own_bytes
		+ info_cache_->getAllocatedMemory()
		+ name_cache_->getAllocatedMemory();
}

// Total allocated memory; for this protocol it is exactly getAllocatedMemory().
uint64_t SSHProtocol::getTotalAllocatedMemory() const {

	const uint64_t total = getAllocatedMemory();
	return total;
}

// Number of failed cache acquisitions: the getTotalFails() counts of the
// SSHInfo cache and the name cache, added in 32-bit unsigned arithmetic.
uint32_t SSHProtocol::getTotalCacheMisses() const {

	const uint32_t info_misses = info_cache_->getTotalFails();
	const uint32_t name_misses = name_cache_->getTotalFails();

	return info_misses + name_misses;
}

// Returns the total length, in bytes, of all the keys stored in name_map_.
uint64_t SSHProtocol::compute_memory_used_by_maps() const {

	uint64_t key_bytes = 0;

	for (const auto &entry: name_map_)
		key_bytes += entry.first.size();

	return key_bytes;
}

// Releases every resource this protocol has attached to the flows.
//
// Requires a live flow manager; when flow_mng_ has expired nothing happens.
// Otherwise it:
//  - detaches every SSHInfo from the flows of the flow table and returns it
//    to info_cache_;
//  - returns every string still referenced by name_map_ to name_cache_ and
//    clears the map;
//  - updates data_time_ and logs a summary of the released memory.
void SSHProtocol::releaseCache() {

	if (FlowManagerPtr fm = flow_mng_.lock(); fm) {
		auto ft = fm->getFlowTable();

		std::ostringstream msg;
		msg << "Releasing " << name() << " cache";

		infoMessage(msg.str());

		int64_t total_cache_bytes_released = compute_memory_used_by_maps();
		int64_t total_bytes_released_by_flows = 0;
		int64_t total_cache_save_bytes = 0;
		int32_t release_flows = 0;
		int32_t release_name = name_map_.size();

		for (auto &flow: ft) {
			if (SharedPointer<SSHInfo> info = flow->getSSHInfo(); info) {
				// Account for the SSHInfo object itself; sizeof(info) would only
				// measure the smart pointer that references it.
				total_bytes_released_by_flows += sizeof(SSHInfo);

				++release_flows;
				flow->layer7info.reset();
				info_cache_->release(info);
			}
		}
		// Some entries may still be in the map; they need to be
		// returned to their cache.
		for (auto &entry: name_map_) {
			// Every hit after the first reused the cached string instead of
			// storing a new copy of it.
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
			releaseStringToCache(name_cache_, entry.second.sc);
		}
		name_map_.clear();

		data_time_ = boost::posix_time::microsec_clock::local_time();

		msg.str("");
		msg << "Release " << release_name << " names, " << release_flows << " flows";
		computeMemoryUtilization(msg, total_cache_bytes_released, total_bytes_released_by_flows, total_cache_save_bytes);
		infoMessage(msg.str());
	}
}

// Returns the SSHInfo attached to the flow, if any, to info_cache_.
// The flow's own reference (layer7info) is not reset here.
void SSHProtocol::releaseFlowInfo(Flow *flow) {

	SharedPointer<SSHInfo> ssh_info = flow->getSSHInfo();
	if (!ssh_info)
		return;

	info_cache_->release(ssh_info);
}

// Sets info->client_name from name, sharing strings through name_map_.
//
// - If the info already has a client name, nothing changes.
// - If name is already in name_map_, the cached string is reused and the
//   entry's hit counter is incremented.
// - Otherwise a string is taken from name_cache_, filled with name, attached
//   and added to name_map_. If the cache has no string available, no name is
//   attached.
void SSHProtocol::attach_ssh_client_name(SSHInfo *info, const boost::string_ref &name) {

	if (info->client_name)
		return;

	StringMap::iterator found = name_map_.find(name);
	if (found != name_map_.end()) {
		auto &cached = found->second;
		++cached.hits;
		info->client_name = cached.sc;
		return;
	}

	SharedPointer<StringCache> fresh = name_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(name.data(), name.length());
	info->client_name = fresh;
	name_map_.insert(std::make_pair(fresh->name(), fresh));
}

// Sets info->server_name from name, sharing strings through name_map_.
//
// - If the info already has a server name, nothing changes.
// - If name is already in name_map_, the cached string is reused and the
//   entry's hit counter is incremented.
// - Otherwise a string is taken from name_cache_, filled with name, attached
//   and added to name_map_. If the cache has no string available, no name is
//   attached.
void SSHProtocol::attach_ssh_server_name(SSHInfo *info, const boost::string_ref &name) {

	if (info->server_name)
		return;

	StringMap::iterator found = name_map_.find(name);
	if (found != name_map_.end()) {
		auto &cached = found->second;
		++cached.hits;
		info->server_name = cached.sc;
		return;
	}

	SharedPointer<StringCache> fresh = name_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(name.data(), name.length());
	info->server_name = fresh;
	name_map_.insert(std::make_pair(fresh->name(), fresh));
}

// Processes one L7 packet of an SSH flow.
//
// Every packet updates the byte/packet counters. An SSHInfo is attached to the
// flow on first use; if info_cache_ cannot provide one, the failure is logged
// and the packet is not analysed further. Packets shorter than header_size are
// only counted. For longer packets:
//  - L7 packets 1 and 2 carry the identification strings. The first is
//    attached as the server name and the second as the client name. The last
//    two bytes, expected to be the CR LF terminator, are dropped.
//  - Later packets, while info->isHandshake() is true, are walked PDU by PDU
//    to count message types. Message type 21 (new keys) clears the client
//    handshake for FORWARD packets and the server handshake otherwise.
//  - Once the handshake is over, packets are only counted as encrypted
//    bytes/packets.
//
// The payload is untrusted network data. The PDU walk therefore never reads
// past the end of the packet and always moves forward, so it terminates.
void SSHProtocol::processFlow(Flow *flow) {

	CPUCycle cycles(&total_cpu_cycles_);
	int length = flow->packet->getLength();
	total_bytes_ += length;
	++total_packets_;
	++flow->total_packets_l7;

	current_flow_ = flow;

	SharedPointer<SSHInfo> info = flow->getSSHInfo();
	if (!info) {
		if (info = info_cache_->acquire(); !info) {
			logFailCache(info_cache_->name(), flow);
			return;
		}
		flow->layer7info = info;
	}

#ifdef DEBUG
	std::cout << __FILE__ << ":" << __func__ << ":" << *flow << " pkts:" << flow->total_packets << std::endl;
#endif

	if (length >= header_size) {
		if (flow->total_packets_l7 > 2) { // Client and server hello done
			if (info->isHandshake()) {
				const uint8_t *payload = flow->packet->getPayload();
				// Bytes that must be present at the current offset to read both
				// the length field and the message type byte (data[0]).
				constexpr int64_t min_pdu_bytes = offsetof(ssh_header, data) + 1;
				// The next PDU starts len + pdu_overhead bytes further on.
				constexpr int64_t pdu_overhead = sizeof(ssh_header) - 1;
				// 64-bit and strictly increasing: len comes from the network and
				// must not be able to wrap the offset back into the packet
				// (which could loop forever).
				int64_t offset = 0;

				while (offset + min_pdu_bytes <= length) {
					const ssh_header *hdr = reinterpret_cast<const ssh_header*>(&payload[offset]);

					uint32_t len = ntohl(hdr->length);
					uint8_t msg_type = static_cast<uint8_t>(hdr->data[0]);

					offset += static_cast<int64_t>(len) + pdu_overhead;

					++total_handshake_pdus_;

					// Message number ranges from RFC 4251, section 7.
					if ((msg_type >= 20)and(msg_type <= 29)) {
						++total_algorithm_negotiation_messages_;
					} else if ((msg_type >= 30)and(msg_type <= 49)) {
						++total_key_exchange_messages_;
					} else {
						++total_others_;
					}

					if (msg_type == 21) { // New keys
						if (flow->getFlowDirection() == FlowDirection::FORWARD)
							info->setClientHandshake(false);
						else
							info->setServerHandshake(false);
					}
				}
			} else {
				// The flow is on encryption mode :)
				info->addEncryptedBytes(length);
				total_encrypted_bytes_ += length;
				++total_encrypted_packets_;
			}
		} else {
			// Identification string; the trailing two bytes (expected CR LF)
			// are not part of the stored name.
			const char *data = reinterpret_cast<const char*>(flow->packet->getPayload());
			boost::string_ref name(data, flow->packet->getLength() - 2);
			if (flow->total_packets_l7 == 1) // The server sending the first data payload
				attach_ssh_server_name(info.get(), name);
			else
				attach_ssh_client_name(info.get(), name);
		}
	}
}

// Forwards value to create() on both caches: the SSHInfo cache first, then
// the name cache.
void SSHProtocol::increaseAllocatedMemory(int value) {

	const int amount = value;

	info_cache_->create(amount);
	name_cache_->create(amount);
}

// Forwards value to destroy() on both caches: the SSHInfo cache first, then
// the name cache.
void SSHProtocol::decreaseAllocatedMemory(int value) {

	const int amount = value;

	info_cache_->destroy(amount);
	name_cache_->destroy(amount);
}

// Writes human-readable statistics to out.
//
// Output, in order:
//  - the common header (always);
//  - the encrypted/other counters (level > 3);
//  - the flow forwarder statistics (level > 5, if the forwarder is alive);
//  - both cache statistics (level > 3);
//  - up to limit entries of name_map_ (level > 4).
void SSHProtocol::statistics(std::basic_ostream<char>& out, int level, int32_t limit) const {

	showStatisticsHeader(out, level);

	const bool detailed = level > 3;

	if (detailed) {
		out << "\t" << "Total encrypted bytes:  " << std::setw(10) << total_encrypted_bytes_ << "\n";
		out << "\t" << "Total encrypted packets:" << std::setw(10) << total_encrypted_packets_ << "\n";
		out << "\t" << "Total other packets:    " << std::setw(10) << total_others_ << std::endl;
	}

	if (level > 5) {
		if (auto forwarder = flow_forwarder_.lock(); forwarder)
			forwarder->statistics(out);
	}

	if (!detailed)
		return;

	info_cache_->statistics(out);
	name_cache_->statistics(out);

	if (level > 4)
		name_map_.show(out, "\t", limit);
}

// Writes statistics to a JSON object: the common header always, plus the
// "encrypted_bytes" and "encrypted_packets" counters when level > 3.
void SSHProtocol::statistics(Json &out, int level) const {

	showStatisticsHeader(out, level);

	if (level <= 3)
		return;

	out["encrypted_bytes"] = total_encrypted_bytes_;
	out["encrypted_packets"] = total_encrypted_packets_;
}

// Returns the protocol counters as key/value pairs, added in this order:
// packets, bytes, encrypted bytes, encrypted packets.
CounterMap SSHProtocol::getCounters() const {

	CounterMap counters;

	counters.addKeyValue("packets", total_packets_);
	counters.addKeyValue("bytes", total_bytes_);
	counters.addKeyValue("encrypted bytes", total_encrypted_bytes_);
	counters.addKeyValue("encrypted packets", total_encrypted_packets_);

	return counters;
}

#if defined(PYTHON_BINDING) || defined(RUBY_BINDING)
// Exports a map to the scripting binding (Python dict or Ruby VALUE).
// "name" (compared case-insensitively) exports name_map_; any other value
// exports the contents of an empty StringMap.
#if defined(PYTHON_BINDING)
boost::python::dict SSHProtocol::getCacheData(const std::string &name) const {
#elif defined(RUBY_BINDING)
VALUE SSHProtocol::getCacheData(const std::string &name) const {
#endif
        if (boost::iequals(name, "name"))
        	return addMapToHash(name_map_);

	// NOTE(review): the two "" arguments presumably fill the map's descriptive
	// fields (name/type); confirm against the StringMap constructor.
	StringMap empty {"", ""};

        return addMapToHash(empty);
}

#if defined(PYTHON_BINDING)
// Python only: returns name_cache_ for "name" (case-insensitive), nullptr
// for any other cache name.
SharedPointer<Cache<StringCache>> SSHProtocol::getCache(const std::string &name) {

        if (boost::iequals(name, "name"))
                return name_cache_;

        return nullptr;
}

#endif

#endif

// Exports up to limit entries of a named map into out. Only "names"
// (case-insensitive), which maps to name_map_, is recognised. Any other
// map_name leaves out untouched.
void SSHProtocol::statistics(Json &out, const std::string &map_name, int32_t limit) const {

	if (!boost::iequals(map_name, "names"))
		return;

	name_map_.show(out, limit);
}

// Resets the common counters via reset(), then zeroes every SSH-specific
// counter (encrypted traffic and handshake message tallies).
void SSHProtocol::resetCounters() {

	reset();

	// Handshake message tallies.
	total_handshake_pdus_ = 0;
	total_algorithm_negotiation_messages_ = 0;
	total_key_exchange_messages_ = 0;
	total_others_ = 0;

	// Encrypted traffic.
	total_encrypted_bytes_ = 0;
	total_encrypted_packets_ = 0;
}

} // namespace aiengine
